import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
from SGDMomentum import run_momentum_batch
from volterraPsi import volterra_Psi


'''========================================================
            Generating concentration plots
                    (Figure 1 in paper)
            
Hyperparameters:
    float R_tilde       :  constant normalization for noise
    float R             :  constant normalization for signal
    float Delta         :  momentum parameter
    float learning_rate :  learning rate
    int   max_iter      :  # of iterations to run SGD+Momentum
    int   beta          :  batch size of SGD+Momentum (not set directly; computed as int(zeta * n))
    float zeta          :  batch fraction beta/n
==========================================================='''


## SGD+Momentum / Volterra hyperparameters shared by the simulation and the plot
R_tilde, R = 1, 1       # normalization constants for the noise and the signal
Delta = 0.5             # momentum parameter
learning_rate = 0.4     # step size (passed to volterra_Psi as gamma)
max_iter = 100          # number of SGD+Momentum iterations per run
zeta = 0.5              # batch fraction: batch size = int(zeta * n)

def samplelossCurve(num_trials, n, d, zeta):
  '''
  Run SGD+Momentum on `num_trials` independent random problems and collect
  one loss curve per problem.

  Each trial draws a Gaussian matrix A (n x d) and a Gaussian vector b
  (length n). Every row of A is rescaled to unit Euclidean norm. The matching
  entry of b is scaled by the same factor times sqrt(R_tilde) * sqrt(d) / sqrt(n).
  run_momentum_batch is then called. The starting point x_0 is drawn once and
  shared by all trials. The module-level hyperparameters R, R_tilde, max_iter,
  learning_rate and Delta are read here.

  @param num_trials:                 number of sample loss curves generated
  @param n:                          number of rows in matrix A
  @param d:                          number of columns in matrix A
  @param zeta:                       batch fraction; the batch size is int(zeta * n)
  @return sample_curves:             list with one loss curve per trial (the second
                                     value returned by run_momentum_batch)
  '''
  # Drawn once outside the loop, so every trial starts from the same x_0.
  x_0 = np.random.randn(d) * (np.sqrt(R)/np.sqrt(n))

  # Loop-invariant quantities.
  batch_size = int(zeta * n)
  b_scale = np.sqrt(R_tilde)*np.sqrt(d)/np.sqrt(n)

  sample_curves = []
  for _ in range(num_trials):
    A = np.random.randn(n, d)
    # Reciprocal row norms of the raw A. Broadcasting applies them row-wise;
    # this is equivalent to left-multiplying by a diagonal matrix.
    inv_row_norms = 1/np.linalg.norm(A, axis=1)
    b = np.random.randn(n)
    A = A * inv_row_norms[:, np.newaxis]
    b = (b * inv_row_norms) * b_scale
    (_, curr_losscurve) = run_momentum_batch(A=A, b=b, x=x_0, n=n, d=d, max_iter=max_iter,
                                             batch_size=batch_size, learning_rate=learning_rate,
                                             Delta=Delta, loss_history=[])
    sample_curves.append(curr_losscurve)
  return sample_curves



'''
dictionary curve_stats
--------------------------------------------------------
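  key:  (a) "Ratio"
        (b) "num_trials"
        (c) row count n of matrix A (one entry per problem size)
 value:
         (a) ratio d/n of columns to rows: Ratio > 1 is the underdetermined case (d > n),
             Ratio < 1 the overdetermined case (d < n)
         (b) number of sample loss curves generated per problem size
         (c) list containing:
          1. numpy array: mean curve (pointwise mean of the sample loss curves)
          2. numpy array: half-width of the 95% confidence interval for the mean curve
             (1.96 * sample std / sqrt(num_trials))
          3. numpy array: pointwise 10th percentile of the sample loss curves
          4. numpy array: pointwise 90th percentile of the sample loss curves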
--------------------------------------------------------
'''

# Global settings first, then one empty [mean, ci, 10th pct, 90th pct] slot
# per problem size n; the statistics loop fills the slots in.
curve_stats = {"Ratio": 2.00, "num_trials": 30}
curve_stats.update({size: [[] for _ in range(4)] for size in (64, 256, 1024, 4096)})

#populate sample curve statistics
# These settings are the same for every problem size, so read them once.
ratio = curve_stats["Ratio"]
num_trials = curve_stats["num_trials"]
for n in [64, 256, 1024, 4096]:
  d = int(n * ratio)   # Ratio is the column/row ratio d/n

  # One row per trial. Assumes every loss curve has the same length,
  # otherwise the array below is not a 2-D float array.
  curr_samplecurves = np.array(samplelossCurve(num_trials, n, d, zeta))
  curr_samplecurves_mean = np.mean(curr_samplecurves, axis=0)
  curr_samplecurves_std = np.std(curr_samplecurves, axis=0, ddof=1)

  # Half-width of the 95% normal confidence interval for the mean curve.
  # 1.96 is the two-sided 95% z-value, not 90%.
  ci = 1.96*(1/np.sqrt(num_trials))*curr_samplecurves_std
  # Pointwise 10th / 90th percentiles across trials (the shaded band in the plot).
  lq, hq = np.quantile(curr_samplecurves, [0.1, 0.9], axis=0)

  # Fill the pre-allocated slot in place: [mean, ci, 10th pct, 90th pct].
  curve_stats[n][:] = [curr_samplecurves_mean, ci, lq, hq]

  

# Figure setup: log-scaled loss axis and large fonts for the paper figure.
cs = plt.get_cmap("viridis")
plt.figure(figsize = (16,10.0))
plt.ylabel("Function Values", fontsize = '40')
plt.xlabel("Iterations", fontsize = '40')
plt.yticks(fontsize = '35')
plt.xticks(fontsize = '35')
plt.yscale("log")


# Shade the 10th-90th percentile band of the sample loss curves for each n.
# Colours are spread evenly over the colormap: first size -> 0.0, last -> 1.0.
sizes = [64, 256, 1024, 4096]
for i, n in enumerate(sizes):
  band = curve_stats[n]
  # The x-axis is taken from the band itself, so it always matches the curve length.
  plt.fill_between(np.arange(len(band[2])), band[2], band[3],
                   facecolor=cs(i / max(len(sizes) - 1, 1)), label = "n = " + str(n))




# Overlay the Volterra prediction psi on the empirical percentile bands,
# using the same hyperparameters as the simulations.
volterraMP_values = volterra_Psi(
    max_iter=max_iter,
    Delta=Delta,
    gamma=learning_rate,
    zeta=zeta,
    ratio=curve_stats["Ratio"],
    R=R,
    R_tilde=R_tilde,
)
ax = plt.gca()
ax.plot(volterraMP_values, c="r", label="Volterra, " + r"$\psi$", linewidth=5.5)
legend = ax.legend(loc="upper right", fontsize='30')
# The Volterra curve is the only Line2D in the legend (the bands are filled
# areas), so index 0 thickens its legend sample.
legend.get_lines()[0].set_linewidth(15)

# Write the PDF before show(), which may clear the figure when the window closes.
plt.gcf().savefig("volterraConcentration.pdf", transparent=True)
plt.show()



